// Copyright 2018 The Grafeas Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vulnerability

import (
	"testing"

	pkgpb "github.com/grafeas/grafeas/proto/v1beta1/package_go_proto"
	vulnpb "github.com/grafeas/grafeas/proto/v1beta1/vulnerability_go_proto"
)

// TestValidateVulnerability runs ValidateVulnerability against a vulnerability
// whose detail list holds a nil entry, one holding an empty detail, and one
// holding a populated detail. It only checks whether any errors were
// returned, not their content.
func TestValidateVulnerability(t *testing.T) {
	cases := []struct {
		desc     string
		v        *vulnpb.Vulnerability
		wantErrs bool
	}{
		{
			desc: "nil detail, want error(s)",
			v: &vulnpb.Vulnerability{
				Severity: vulnpb.Severity_CRITICAL,
				Details:  []*vulnpb.Vulnerability_Detail{nil},
			},
			wantErrs: true,
		},
		{
			desc: "invalid vulnerability detail, want error(s)",
			v: &vulnpb.Vulnerability{
				Severity: vulnpb.Severity_CRITICAL,
				Details:  []*vulnpb.Vulnerability_Detail{{}},
			},
			wantErrs: true,
		},
		{
			desc: "valid vulnerability, want success",
			v: &vulnpb.Vulnerability{
				Severity: vulnpb.Severity_CRITICAL,
				Details: []*vulnpb.Vulnerability_Detail{
					{
						CpeUri:       "cpe:/o:debian:debian_linux:7",
						Package:      "debian",
						SeverityName: "LOW",
					},
				},
			},
			wantErrs: false,
		},
	}

	for _, tc := range cases {
		errs := ValidateVulnerability(tc.v)
		t.Logf("%q: error(s): %v", tc.desc, errs)

		// The two mismatch outcomes are mutually exclusive, so a single switch
		// covers them; a match on either side falls through silently.
		gotErrs := len(errs) > 0
		switch {
		case tc.wantErrs && !gotErrs:
			t.Errorf("%q: ValidateVulnerability(%+v): got success, want error(s)", tc.desc, tc.v)
		case !tc.wantErrs && gotErrs:
			t.Errorf("%q: ValidateVulnerability(%+v): got error(s) %v, want success", tc.desc, tc.v, errs)
		}
	}
}

// TestValidateVulnerabilityDetail checks that validateVulnerabilityDetail
// reports errors when the CPE URI or package is missing, or when an optional
// nested message (min/max affected version, fixed location) is present but
// empty, and reports none for a minimal populated detail.
//
// Each failing case copies the fields of the success case and changes only
// the one named in its description, so the error can only come from that
// condition.
func TestValidateVulnerabilityDetail(t *testing.T) {
	tests := []struct {
		desc     string
		vd       *vulnpb.Vulnerability_Detail
		wantErrs bool
	}{
		{
			desc: "missing CPE URI, want error(s)",
			vd: &vulnpb.Vulnerability_Detail{
				Package:      "debian",
				SeverityName: "LOW",
			},
			wantErrs: true,
		},
		{
			desc: "missing package, want error(s)",
			vd: &vulnpb.Vulnerability_Detail{
				CpeUri:       "cpe:/o:debian:debian_linux:7",
				SeverityName: "LOW",
			},
			wantErrs: true,
		},
		{
			desc: "invalid min affected version, want error(s)",
			vd: &vulnpb.Vulnerability_Detail{
				CpeUri:             "cpe:/o:debian:debian_linux:7",
				Package:            "debian",
				SeverityName:       "LOW",
				MinAffectedVersion: &pkgpb.Version{},
			},
			wantErrs: true,
		},
		{
			desc: "invalid max affected version, want error(s)",
			vd: &vulnpb.Vulnerability_Detail{
				CpeUri:             "cpe:/o:debian:debian_linux:7",
				Package:            "debian",
				SeverityName:       "LOW",
				MaxAffectedVersion: &pkgpb.Version{},
			},
			wantErrs: true,
		},
		{
			desc: "invalid fixed location set, want error(s)",
			vd: &vulnpb.Vulnerability_Detail{
				CpeUri:        "cpe:/o:debian:debian_linux:7",
				Package:       "debian",
				SeverityName:  "LOW",
				FixedLocation: &vulnpb.VulnerabilityLocation{},
			},
			wantErrs: true,
		},
		{
			desc: "valid vulnerability details, want success",
			vd: &vulnpb.Vulnerability_Detail{
				CpeUri:       "cpe:/o:debian:debian_linux:7",
				Package:      "debian",
				SeverityName: "LOW",
			},
			wantErrs: false,
		},
	}

	for _, tt := range tests {
		errs := validateVulnerabilityDetail(tt.vd)
		t.Logf("%q: error(s): %v", tt.desc, errs)
		if len(errs) == 0 && tt.wantErrs {
			t.Errorf("%q: validateVulnerabilityDetail(%+v): got success, want error(s)", tt.desc, tt.vd)
		}
		if len(errs) > 0 && !tt.wantErrs {
			t.Errorf("%q: validateVulnerabilityDetail(%+v): got error(s) %v, want success", tt.desc, tt.vd, errs)
		}
	}
}

// TestValidateVulnerabilityLocation checks that validateVulnerabilityLocation
// reports errors when the CPE URI, package, or version is missing, or when the
// version is present but empty, and reports none for a fully populated
// location.
//
// Each failing case supplies every field except the one named in its
// description, so a regression in any single check cannot be hidden by a
// different missing field.
func TestValidateVulnerabilityLocation(t *testing.T) {
	tests := []struct {
		desc     string
		vl       *vulnpb.VulnerabilityLocation
		wantErrs bool
	}{
		{
			desc: "missing CPE URI, want error(s)",
			vl: &vulnpb.VulnerabilityLocation{
				Package: "debian",
				Version: &pkgpb.Version{
					Name: "1.1.2",
					Kind: pkgpb.Version_NORMAL,
				},
			},
			wantErrs: true,
		},
		{
			desc: "missing package, want error(s)",
			vl: &vulnpb.VulnerabilityLocation{
				CpeUri: "cpe:/o:debian:debian_linux:7",
				Version: &pkgpb.Version{
					Name: "1.1.2",
					Kind: pkgpb.Version_NORMAL,
				},
			},
			wantErrs: true,
		},
		{
			desc: "missing version, want error(s)",
			vl: &vulnpb.VulnerabilityLocation{
				CpeUri:  "cpe:/o:debian:debian_linux:7",
				Package: "debian",
			},
			wantErrs: true,
		},
		{
			desc: "version set, but invalid, want error(s)",
			vl: &vulnpb.VulnerabilityLocation{
				CpeUri:  "cpe:/o:debian:debian_linux:7",
				Package: "debian",
				Version: &pkgpb.Version{},
			},
			wantErrs: true,
		},
		{
			desc: "version set and valid, want success",
			vl: &vulnpb.VulnerabilityLocation{
				CpeUri:  "cpe:/o:debian:debian_linux:7",
				Package: "debian",
				Version: &pkgpb.Version{
					Name: "1.1.2",
					Kind: pkgpb.Version_NORMAL,
				},
			},
			wantErrs: false,
		},
	}

	for _, tt := range tests {
		errs := validateVulnerabilityLocation(tt.vl)
		t.Logf("%q: error(s): %v", tt.desc, errs)
		if len(errs) == 0 && tt.wantErrs {
			t.Errorf("%q: validateVulnerabilityLocation(%+v): got success, want error(s)", tt.desc, tt.vl)
		}
		if len(errs) > 0 && !tt.wantErrs {
			t.Errorf("%q: validateVulnerabilityLocation(%+v): got error(s) %v, want success", tt.desc, tt.vl, errs)
		}
	}
}

// TestValidateDetails runs ValidateDetails against details with no package
// issues (nil and empty slice), a nil package-issue element, a package issue
// with an empty affected location, and a fully populated package issue. It
// only checks whether any errors were returned.
func TestValidateDetails(t *testing.T) {
	validAffected := func() *vulnpb.VulnerabilityLocation {
		return &vulnpb.VulnerabilityLocation{
			CpeUri:  "cpe:/o:debian:debian_linux:7",
			Package: "debian",
			Version: &pkgpb.Version{
				Name: "1.1.2",
				Kind: pkgpb.Version_NORMAL,
			},
		}
	}

	cases := []struct {
		desc     string
		d        *vulnpb.Details
		wantErrs bool
	}{
		{
			desc:     "missing package issue, want error(s)",
			d:        &vulnpb.Details{},
			wantErrs: true,
		},
		{
			desc:     "empty package issue, want error(s)",
			d:        &vulnpb.Details{PackageIssue: []*vulnpb.PackageIssue{}},
			wantErrs: true,
		},
		{
			desc:     "nil package issue element, want error(s)",
			d:        &vulnpb.Details{PackageIssue: []*vulnpb.PackageIssue{nil}},
			wantErrs: true,
		},
		{
			desc: "invalid package issue, want error(s)",
			d: &vulnpb.Details{
				PackageIssue: []*vulnpb.PackageIssue{
					{AffectedLocation: &vulnpb.VulnerabilityLocation{}},
				},
			},
			wantErrs: true,
		},
		{
			desc: "valid details, want success",
			d: &vulnpb.Details{
				PackageIssue: []*vulnpb.PackageIssue{
					{AffectedLocation: validAffected()},
				},
			},
			wantErrs: false,
		},
	}

	for _, tc := range cases {
		errs := ValidateDetails(tc.d)
		t.Logf("%q: error(s): %v", tc.desc, errs)

		gotErrs := len(errs) != 0
		if gotErrs == tc.wantErrs {
			continue
		}
		if tc.wantErrs {
			t.Errorf("%q: ValidateDetails(%+v): got success, want error(s)", tc.desc, tc.d)
		} else {
			t.Errorf("%q: ValidateDetails(%+v): got error(s) %v, want success", tc.desc, tc.d, errs)
		}
	}
}

// TestValidatePackageIssue checks that validatePackageIssue reports errors
// when the affected location is missing or empty, or when a fixed location is
// present but empty, and reports none for a package issue with a fully
// populated affected location.
//
// The fixed-location case also carries a valid affected location, so the
// only possible source of error is the empty fixed location.
func TestValidatePackageIssue(t *testing.T) {
	tests := []struct {
		desc     string
		p        *vulnpb.PackageIssue
		wantErrs bool
	}{
		{
			desc:     "missing affected location, want error(s)",
			p:        &vulnpb.PackageIssue{},
			wantErrs: true,
		},
		{
			desc: "invalid affected location, want error(s)",
			p: &vulnpb.PackageIssue{
				AffectedLocation: &vulnpb.VulnerabilityLocation{},
			},
			wantErrs: true,
		},
		{
			desc: "invalid fixed location, want error(s)",
			p: &vulnpb.PackageIssue{
				// A valid affected location is required here; without it the
				// case would fail on "missing affected location" and never
				// exercise fixed-location validation.
				AffectedLocation: &vulnpb.VulnerabilityLocation{
					CpeUri:  "cpe:/o:debian:debian_linux:7",
					Package: "debian",
					Version: &pkgpb.Version{
						Name: "1.1.2",
						Kind: pkgpb.Version_NORMAL,
					},
				},
				FixedLocation: &vulnpb.VulnerabilityLocation{},
			},
			wantErrs: true,
		},
		{
			desc: "valid package issue, want success",
			p: &vulnpb.PackageIssue{
				AffectedLocation: &vulnpb.VulnerabilityLocation{
					CpeUri:  "cpe:/o:debian:debian_linux:7",
					Package: "debian",
					Version: &pkgpb.Version{
						Name: "1.1.2",
						Kind: pkgpb.Version_NORMAL,
					},
				},
			},
			wantErrs: false,
		},
	}

	for _, tt := range tests {
		errs := validatePackageIssue(tt.p)
		t.Logf("%q: error(s): %v", tt.desc, errs)
		if len(errs) == 0 && tt.wantErrs {
			t.Errorf("%q: validatePackageIssue(%+v): got success, want error(s)", tt.desc, tt.p)
		}
		if len(errs) > 0 && !tt.wantErrs {
			t.Errorf("%q: validatePackageIssue(%+v): got error(s) %v, want success", tt.desc, tt.p, errs)
		}
	}
}
